import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.special import sph_harm

# Small numerical tolerance. It is not referenced by the code in this section;
# presumably used elsewhere in the project — TODO confirm before changing it.
EPS = 1.0e-16

def associated_legendre_polynomials(L, x):
    """Evaluate scaled associated Legendre functions P_l^m(x) for 0 <= m <= l < L.

    Starting from P_0^0 = 1, the table is filled with the recurrences

        P_l^l     = -sqrt((2l-1)/(2l)) * sqrt(1 - x^2) * P_{l-1}^{l-1}
        P_{m+1}^m = sqrt(2m+1) * x * P_m^m
        P_l^m     = ((2l-1) x P_{l-1}^m - sqrt((l-1)^2 - m^2) P_{l-2}^m) / sqrt(l^2 - m^2)

    which correspond to P_l^m scaled by sqrt((l-m)!/(l+m)!). The leading minus
    sign of the sectoral recurrence carries the Condon-Shortley phase.

    Args:
        L: number of degrees (l = 0 .. L-1).
        x: tensor of arguments (cos of the polar angle for spherical harmonics).

    Returns:
        Tensor of shape (L*(L+1)//2, *x.shape); the (l, m) entry is at index
        l*(l+1)//2 + m.
    """
    def flat(l, m):
        # Position of (l, m) in the triangular (l-major) layout.
        return l * (l + 1) // 2 + m

    table = [torch.ones_like(x) for _ in range(L * (L + 1) // 2)]

    # Sectoral (diagonal) terms P_l^l.
    for l in range(1, L):
        table[flat(l, l)] = -np.sqrt((2 * l - 1) / (2 * l)) * torch.sqrt(1 - x**2) * table[flat(l - 1, l - 1)]

    for m in range(L - 1):
        # First term below the diagonal, P_{m+1}^m.
        table[flat(m + 1, m)] = x * np.sqrt(2 * m + 1) * table[flat(m, m)]
        # Remaining degrees for this order via the three-term recurrence in l.
        for l in range(m + 2, L):
            denom = np.sqrt(l**2 - m**2)
            table[flat(l, m)] = (
                (2 * l - 1) * x * table[flat(l - 1, m)] / denom
                - table[flat(l - 2, m)] * np.sqrt(((l - 1)**2 - m**2) / (l**2 - m**2))
            )
    return torch.stack(table, dim=0)

def spherical_harmonics(L, THETA, PHI):
    """Evaluate real spherical harmonics of degree 0 <= l < L in PyTorch.

    Args:
        L: number of degrees; L**2 harmonics are produced.
        THETA: azimuthal angle tensor (enters through cos/sin(m * THETA)).
        PHI: polar angle tensor (enters through cos(PHI)).

    Returns:
        Tensor stacked along dim 0 with L**2 entries. Harmonic (l, m),
        -l <= m <= l, is stored at index l**2 + l + m:
          m == 0: N_l * P_l^0(cos PHI)
          m >  0: sqrt(2) * N_l * P_l^m(cos PHI) * cos(m * THETA)
          m <  0: sqrt(2) * N_l * P_l^|m|(cos PHI) * sin(|m| * THETA)
        with N_l = sqrt((2l+1) / (4 pi)).
    """
    legendre = associated_legendre_polynomials(L, torch.cos(PHI))
    harmonics = [torch.zeros_like(THETA) for _ in range(L**2)]

    # Azimuthal factors shared by every degree; slot m holds cos/sin(m * THETA).
    cosines = [torch.ones_like(THETA)] + [torch.cos(m * THETA) for m in range(1, L)]
    sines = [None] + [torch.sin(m * THETA) for m in range(1, L)]
    root2 = np.sqrt(2)

    for l in range(L):
        norm = np.sqrt((2 * l + 1) / (4 * np.pi))
        row = l * (l + 1) // 2      # start of degree l in the Legendre table
        centre = l * l + l          # output slot of m == 0
        harmonics[centre] = legendre[row] * norm * cosines[0]
        for m in range(1, l + 1):
            harmonics[centre + m] = root2 * legendre[row + m] * norm * cosines[m]
            harmonics[centre - m] = root2 * legendre[row + m] * norm * sines[m]
    return torch.stack(harmonics, dim=0)

def spherical_harmonics_scipy(L, THETA, PHI):
    """Evaluate real spherical harmonics of degree 0 <= l < L with SciPy.

    Uses ``scipy.special.sph_harm(m, l, THETA, PHI)``, i.e. THETA is the
    azimuthal angle and PHI the polar angle (SciPy's argument order). The real
    harmonics are formed from the complex ones as

        index l**2 + l      -> Re Y_l^0
        index l**2 + l + m  -> sqrt(2) * Re Y_l^m   (m > 0)
        index l**2 + l - m  -> sqrt(2) * Im Y_l^m   (m > 0)

    NOTE(review): ``sph_harm`` is deprecated in recent SciPy releases in favour
    of ``sph_harm_y`` (which uses a different argument order) — confirm the
    pinned SciPy version before upgrading.

    Args:
        L: number of degrees; L**2 harmonics are returned.
        THETA: azimuthal angle(s), array-like.
        PHI: polar angle(s), array-like, broadcastable with THETA.

    Returns:
        float32 tensor of shape (L**2, *broadcast_shape(THETA, PHI)). Entries
        whose absolute value is below 1e-7 are set to exactly 0. For L <= 0 an
        empty tensor with leading dimension 0 is returned.
    """
    if L <= 0:
        # torch.stack([]) would raise; return an empty but well-shaped result.
        return torch.zeros((0,) + np.broadcast(THETA, PHI).shape, dtype=torch.float32)

    output = [None] * (L**2)
    for l in range(L):
        centre = l**2 + l
        output[centre] = torch.tensor(sph_harm(0, l, THETA, PHI).real, dtype=torch.float32)
        for m in range(1, l + 1):
            # One complex evaluation supplies both the cosine (real part) and
            # sine (imaginary part) real harmonics.
            y = sph_harm(m, l, THETA, PHI)
            output[centre + m] = torch.tensor(np.sqrt(2) * y.real, dtype=torch.float32)
            output[centre - m] = torch.tensor(np.sqrt(2) * y.imag, dtype=torch.float32)

    stacked = torch.stack(output)
    # Snap values smaller than 1e-7 in magnitude to exact zeros.
    return torch.where(torch.abs(stacked) < 1.0e-7, torch.zeros_like(stacked), stacked)

def legendre_polynomials(L, x):
    """Evaluate the Legendre polynomials P_0 .. P_{L-1} at x.

    Uses Bonnet's recurrence  l * P_l = (2l-1) x P_{l-1} - (l-1) P_{l-2},
    seeded with P_0 = 1 and P_1 = x.

    Args:
        L: number of polynomials to evaluate.
        x: tensor of evaluation points.

    Returns:
        Tensor of shape (L, *x.shape) with P_l(x) at index l.
    """
    polys = []
    for degree in range(L):
        if degree == 0:
            polys.append(torch.ones_like(x))
        elif degree == 1:
            polys.append(x)
        else:
            polys.append(((2 * degree - 1) * x * polys[-1] - (degree - 1) * polys[-2]) / degree)
    return torch.stack(polys, dim=0)


class InvLocalFeatOrientConvolution(nn.Module):
    """3D convolution built from a fixed spherical-harmonic x radial-shell basis.

    Every input channel is convolved independently with a bank of fixed
    ``kernel_size**3`` filters. Each filter is the product of:
      * a real spherical harmonic of degree ``l < order``, evaluated at the
        direction of each voxel from the kernel centre; and
      * the indicator mask of one radial shell, i.e. the voxels that share the
        same squared distance from the centre.
    For every row ``(l, a, b)`` of ``wigner_indices``, the responses for
    harmonic ``b`` are mixed over (input channel, shell) with the learned
    weights stored for harmonic ``a``. The results are squared and summed over
    all rows, giving ``num_outputs`` channels.

    NOTE(review): the class name suggests the squared sum over ``m`` within
    each degree is meant to make the response invariant to local rotations —
    not verified here.

    Args:
        num_inputs: number of input channels ``C``.
        num_outputs: number of output channels ``E``.
        kernel_size: odd cubic kernel size.
        order: number of spherical-harmonic degrees (``l = 0 .. order-1``).
        stride, padding: forwarded to ``F.conv3d``.
        dilation_rate: forwarded to ``F.conv3d`` as ``dilation``.
        bias: if True, a learned per-output-channel bias is added.
    """

    def __init__(self, num_inputs, num_outputs, kernel_size, order, stride, padding, dilation_rate=1, bias=True):
        super(InvLocalFeatOrientConvolution, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        # An odd kernel has a single centre voxel that the spherical grid is built around.
        assert kernel_size % 2 == 1
        self.kernel_size = kernel_size
        self.order = order
        self.stride = stride
        self.padding = padding
        self.dilation_rate = dilation_rate
        self.bias = bias
        # Kept for attribute compatibility; not used by the forward pass.
        self.hidden_size = (self.num_inputs + self.num_outputs)

        centre = (kernel_size - 1) // 2
        # Largest squared distance from the centre (a corner voxel); used to
        # rescale squared radii into (-1, 1].
        k_max = 3 * centre**2

        # One row per (l, a, b) with a, b ranging over the 2l+1 harmonics of
        # degree l: [l, row of kernel_weight used, row of conv response used].
        self.orders_indices = []
        self.wigner_indices = []
        for l in range(self.order):
            self.orders_indices += (2 * l + 1) * [[l]]
            self.wigner_indices += [[l, l**2 + i % (2 * l + 1), l**2 + i // (2 * l + 1)]
                                    for i in range((2 * l + 1)**2)]
        self.wi_size = len(self.wigner_indices)
        self.wigner_indices = np.array(self.wigner_indices)
        self.orders_indices = np.array(self.orders_indices)

        rad_list = []  # distinct radial codes, in first-seen order
        masks = []     # masks[n] marks the voxels whose radial code is rad_list[n]

        # Per voxel: [radial code, polar angle, azimuthal angle].
        spherical_coords = np.zeros((kernel_size, kernel_size, kernel_size, 3), dtype=np.float32)
        for i in range(kernel_size):
            for j in range(kernel_size):
                for k in range(kernel_size):
                    i1 = i - centre
                    j1 = j - centre
                    k1 = k - centre
                    if i1 == 0 and j1 == 0 and k1 == 0:
                        # Centre voxel: no direction; it gets its own radial code.
                        spherical_coords[i, j, k, 1] = 0.0
                        spherical_coords[i, j, k, 2] = 0.0
                        if (kernel_size - 1) == 0:
                            spherical_coords[i, j, k, 0] = 0.0
                        else:
                            spherical_coords[i, j, k, 0] = -1.0
                    else:
                        spherical_coords[i, j, k, 0] = i1**2 + j1**2 + k1**2
                        spherical_coords[i, j, k, 0] = (2 * spherical_coords[i, j, k, 0] / k_max - 1.0)
                        spherical_coords[i, j, k, 1] = np.arctan2(np.sqrt(i1**2 + j1**2), k1)
                        spherical_coords[i, j, k, 2] = np.arctan2(j1, i1)
                    # Group voxels with identical radial code into one shell mask.
                    if spherical_coords[i, j, k, 0] not in rad_list:
                        rad_list.append(spherical_coords[i, j, k, 0])
                        m = len(rad_list) - 1
                        masks.append(np.zeros((kernel_size, kernel_size, kernel_size), dtype=np.float32))
                    else:
                        m = rad_list.index(spherical_coords[i, j, k, 0])
                    masks[m][i, j, k] = 1.0

        # Number of distinct radial shells.
        self.leg_order = len(rad_list)
        self.sph_harm = spherical_harmonics_scipy(self.order, spherical_coords[..., 2], spherical_coords[..., 1])
        if self.order > 1:
            # Direction is undefined at the centre voxel: keep only the l = 0 term there.
            self.sph_harm[1:, centre, centre, centre] = 0.0
        self.spherical_coords = torch.tensor(spherical_coords, dtype=torch.float32)
        self.sph_harm_gath = self.sph_harm[self.wigner_indices[:, 2]]
        self.masks = torch.tensor(np.stack(masks, axis=0), dtype=torch.float32)

        # Filter bank (shells * harmonics, 1, k, k, k), shell-major. Registered as a
        # non-persistent buffer so it follows .to()/.cuda()/.half() with the
        # module while staying out of state_dict.
        basis = torch.einsum('lijk,rijk->rlijk', self.sph_harm, self.masks)
        self.register_buffer(
            'basis_functions',
            basis.reshape(-1, 1, self.kernel_size, self.kernel_size, self.kernel_size),
            persistent=False,
        )

        # Learned mixing weights: (harmonics, shells, input channels, output channels).
        self.kernel_weight = nn.Parameter(torch.randn(self.order**2, self.leg_order, self.num_inputs, self.num_outputs))

        if self.bias:
            self.param_bias = nn.Parameter(torch.zeros(self.num_outputs))

    def forward(self, input):
        """Apply the layer.

        Args:
            input: tensor of shape (B, num_inputs, D, H, W).

        Returns:
            Tensor of shape (B, num_outputs, D', H', W'). The spatial size is
            whatever ``F.conv3d`` produces for the configured stride, padding
            and dilation.
        """
        # Defensive: make sure the basis sits on the weights' device (a no-op
        # when the buffer already moved with the module). Module state is not mutated.
        basis = self.basis_functions.to(self.kernel_weight.device)

        # Fold channels into the batch so each input channel is convolved
        # independently with every (shell, harmonic) filter: (B*C, 1, D, H, W).
        input_reshaped = input.reshape(-1, 1, input.shape[2], input.shape[3], input.shape[4])
        conv = F.conv3d(input_reshaped, basis, stride=self.stride, padding=self.padding, dilation=self.dilation_rate)

        # (B, C, shells, harmonics, D', H', W'); the spatial size is taken from the
        # conv output rather than assumed to be D // stride.
        conv_reshaped = conv.view(-1, self.num_inputs, self.leg_order, self.order**2, *conv.shape[2:])
        conv_reconstruct = torch.einsum(
            'bdrlxyz,lrde->blexyz',
            conv_reshaped[:, :, :, self.wigner_indices[:, 2]],
            self.kernel_weight[self.wigner_indices[:, 1]],
        )
        # Squared responses summed over every (l, a, b) row.
        output = torch.sum(conv_reconstruct**2, dim=1)
        if self.bias:
            output = output + self.param_bias.view(1, -1, 1, 1, 1)
        return output
    
    


